
[air] Horovod: Use Torch.encode_data if torch is imported #28440

Merged

merged 6 commits into ray-project:master from train/horovod-encode-torch on Sep 13, 2022

Conversation

krfricke (Contributor) opened this pull request:

Signed-off-by: Kai Fricke <[email protected]>

Why are these changes needed?

Horovod with Tune does not work out of the box for GPU checkpoints, as they get deserialized on the non-GPU trainer worker, leading to errors. With this PR, we detect whether torch is imported and a tensor is supplied in the Horovod backend. If so, we use the torch backend to serialize the data.
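A minimal sketch of the idea, not the actual Ray Train internals: `encode_data` and the `_serialized_torch_payload` key are illustrative stand-ins, while `contains_tensor` mirrors the helper added in this PR.

```python
import io
import sys


def contains_tensor(obj) -> bool:
    # Recursively look for a torch.Tensor anywhere in the payload.
    import torch  # only called after torch is known to be imported

    if isinstance(obj, torch.Tensor):
        return True
    if isinstance(obj, dict):
        return any(contains_tensor(k) or contains_tensor(v) for k, v in obj.items())
    if isinstance(obj, (list, tuple)):
        return any(contains_tensor(item) for item in obj)
    return False


def encode_data(data_dict: dict) -> dict:
    # Only take the torch path when torch is already imported: checking
    # sys.modules avoids pulling torch into torch-free environments.
    if "torch" in sys.modules and contains_tensor(data_dict):
        import torch

        # torch.save knows how to handle tensors; loading the blob with
        # torch.load(..., map_location="cpu") on the receiving worker
        # keeps GPU tensors off machines without GPUs.
        buffer = io.BytesIO()
        torch.save(data_dict, buffer)
        return {"_serialized_torch_payload": buffer.getvalue()}
    return data_dict
```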

Related issue number

Closes #28439

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

amogkam (Contributor) left a comment:

Thanks @krfricke, lgtm as a stopgap fix! But ultimately we should refactor the checkpoint encoding/decoding logic out of the Backends and into the framework-specific checkpoints.

Then, when saving a Torch model via TorchCheckpoint.from_model(), the same encode/decode logic will apply regardless of whether I'm using TorchTrainer or HorovodTrainer.

Made an issue to track this here: #28462
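A rough sketch of the proposed shape. The class and method names below (other than from_model, which the comment mentions) are illustrative stand-ins, not the actual ray.air API:

```python
import io

import torch


class TorchCheckpointSketch:
    """Stand-in for a framework-specific checkpoint that owns its own
    encode/decode logic, independent of which training backend made it."""

    def __init__(self, state_dict: dict):
        self._state_dict = state_dict

    @classmethod
    def from_model(cls, model: torch.nn.Module) -> "TorchCheckpointSketch":
        # Capture a CPU copy of the weights so the checkpoint is safe to
        # deserialize on workers without GPUs.
        state = {k: v.cpu() for k, v in model.state_dict().items()}
        return cls(state)

    def encode(self) -> bytes:
        buffer = io.BytesIO()
        torch.save(self._state_dict, buffer)
        return buffer.getvalue()

    @classmethod
    def decode(cls, blob: bytes) -> "TorchCheckpointSketch":
        # map_location="cpu" guarantees tensors land on CPU regardless
        # of the device they were saved from.
        state = torch.load(io.BytesIO(blob), map_location="cpu")
        return cls(state)
```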

@@ -190,3 +190,19 @@ def load_torch_model(
f"to be of type `torch.nn.Module`, or a model "
f"state dict of type dict."
)


def contains_tensor(obj):
amogkam (Contributor) commented on contains_tensor:

Do we need this? I think if torch is installed, it should be safe to always use the TorchBackend for encoding/decoding (even if the data dict does not contain a tensor). I'm worried that, in the worst case, contains_tensor can lead to a lot of recursion.

krfricke (Contributor, Author) replied:

I originally didn't have it in, but then figured the overhead wouldn't be too bad. But I agree: since this only concerns an internal communication channel, and the intermediate objects are not exposed to the user, we can just always do this when torch is loaded. Updated the PR.

krfricke (Contributor, Author) followed up:

Turns out we do need it, as torch.save seems to silently fail if a full model is passed (rather than a state dict).

I think it should be fine: a similar lookup has to happen during pickling anyway, and in most cases it should finish early.
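To illustrate the "finishes early" point, using the hypothetical contains_tensor sketch from the PR description above (not the exact PR code): a typical state dict hits a tensor at the very first value, so the recursion stops immediately, while the worst case is a deeply nested, tensor-free payload.

```python
import torch

# Common case: a state dict maps parameter names directly to tensors,
# so contains_tensor returns True on the first value it inspects.
state_dict = {
    "linear.weight": torch.zeros(4, 4),
    "linear.bias": torch.zeros(4),
}
assert contains_tensor(state_dict)

# Worst case: a deeply nested, tensor-free structure forces a full
# traversal -- the recursion concern raised above.
nested = {"a": [{"b": [0] * 1000}]}
assert not contains_tensor(nested)
```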

This reverts commit 2a21445.

Signed-off-by: Kai Fricke <[email protected]>
This reverts commit 93913af.

Signed-off-by: Kai Fricke <[email protected]>
@krfricke krfricke merged commit 3292ce8 into ray-project:master Sep 13, 2022
@krfricke krfricke deleted the train/horovod-encode-torch branch September 13, 2022 10:33
Development

Successfully merging this pull request may close these issues.

[train] Horovod+Torch does not convert GPU tensors to CPU